// RUN: stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify-diagnostics %s > %t.mlir
// RUN: stablehlo-translate --interpret --split-input-file %t.mlir

// Ragged dot whose ragged dimension is a non-contracting lhs dimension (the
// "mode 1" case in the function name). lhs dimension 0 (11 rows) is split into
// consecutive groups of sizes [4, 4, 3]. Each group is multiplied by the rhs
// matrix selected along rhs_group_dimensions = [0]:
//   lhs rows 0-3  x rhs[0]
//   lhs rows 4-7  x rhs[1]
//   lhs rows 8-10 x rhs[2]
// No batching dimensions are used; the contracted dimension has size 5.
func.func @ragged_dot_mode_1() {
  // lhs: 11x5 -- ragged along dimension 0, contracted along dimension 1.
  %lhs = stablehlo.constant dense<
    [
      [ -0.0999976546, -0.0605386607, 0.126681596, 0.0375950411, 0.0598301813 ],
      [ -0.0343122408, -0.0858866125, 0.103659429, 0.103788935, 0.180407882 ],
      [ 0.0150506198, 0.055824928, 0.149289608, -0.0896283686, -0.0839615092 ],
      [ 0.0589100644, 0.101344816, -0.097690545, 0.0150246918, -0.0799473301 ],
      [ 0.0252457932, 0.106031813, 0.076692991, 0.179130971, 0.153850079 ],
      [ 0.0580786392, -0.0724105313, 0.0961757079, 0.0247998089, 0.110357188 ],
      [ 0.173096269, 0.128659427, -0.0212640986, -0.0857606456, 0.120824583 ],
      [ -0.00152973086, 0.0897915736, 0.126923144, 0.197311223, 0.00960160792 ],
      [ -0.0258883312, 0.194765091, 0.11679814, 0.126006752, 0.0954555795 ],
      [ -0.0781942382, 0.0894904211, 0.165412158, -0.0181870088, 0.0309234336 ],
      [ 0.129948437, 0.0433195308, -0.028667666, -0.0175279453, 0.00777949393 ]
    ]> : tensor<11x5xf32>
  // rhs: 3x5x7 -- one 5x7 matrix per group (group dimension 0), contracted
  // along dimension 1.
  %rhs = stablehlo.constant dense<[
    [
      [ 0.186608255, 0.124487795, 0.0663751587, 0.167221248, 0.0874548, 0.152611881, -0.0520697422 ],
      [ -0.0361745432, 0.114412986, -0.0608718246, -0.0727029, -0.0176235586, -0.0991001204, 0.0242879838 ],
      [ -0.0919371173, 0.112945892, 0.181369215, -0.0280267522, -0.0457312278, -0.00473813713, 0.166097224 ],
      [ 0.0956176, -0.0548994839, 0.104403876, 0.0157444105, 0.0163175985, 0.0499223098, -0.0557401 ],
      [ 0.076156, 0.153672695, 0.0770325884, 0.186622649, 0.066843845, -0.0555545315, 0.194991559 ]
    ],
    [
      [ 0.00485724211, 0.0356900468, 0.142683387, 0.179502338, 0.0954938307, -0.0354254842, 0.103877716 ],
      [ 0.172676593, -0.0249623209, 0.158257961, 0.0413787, 0.0517867729, 0.0801181123, 0.14526847 ],
      [ 0.126753062, 0.0386734977, 0.185410261, 0.0898216143, 0.0317991, 0.14740923, 0.106694289 ],
      [ 0.110662006, 0.196143657, 0.186324477, 0.155380905, -0.0132051334, 0.0612277314, 0.054330416 ],
      [ -0.0689698234, 0.0242085531, 0.073015, 0.162969738, 0.0320116058, 0.118924297, 0.160779119 ]
    ],
    [
      [ -0.0998214856, -0.0997363, 0.132005602, 0.118200503, -0.00424671918, 0.025317125, 0.104748271 ],
      [ 0.104168601, -0.0384214334, 0.150926, 0.112676181, 0.14861238, -0.071635358, -0.0754787177 ],
      [ 0.129201442, 0.088871561, -0.0358443409, -0.0359359607, -0.0756817609, 0.0166469738, 0.185647905 ],
      [ 0.184263527, 0.0169560835, -0.0192355737, 0.10765069, -0.0147894919, 0.13305977, 0.135159582 ],
      [ 0.0267379507, -0.0153532401, -0.0418097563, -0.096605137, -0.0424528457, 0.194970757, -0.0267837271 ]
    ]]> : tensor<3x5x7xf32>
  // Row counts of the consecutive lhs groups; here they sum to 11, the size of
  // the lhs ragged dimension.
  %group_sizes = stablehlo.constant dense<[4, 4, 3]> : tensor<3xi64>
  %result = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [],
      rhs_batching_dimensions = [],
      lhs_contracting_dimensions = [1],
      rhs_contracting_dimensions = [1],
      lhs_ragged_dimensions = [0],
      rhs_group_dimensions = [0]
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
  // Golden 11x7 result: row i is lhs row i times the 5x7 rhs matrix of the
  // group that row i belongs to. The comparison is approximate; its tolerance
  // is defined by the check dialect, not in this file.
  check.expect_almost_eq_const %result, dense<[
    [-0.0199659951, 0.00206358638, 0.0285578221, -0.00411329232, -0.00885893404, -0.0113086831, 0.0343487822],
    [0.0108370036, 0.0196357146, 0.0464844741, 0.032903526, 0.00752512738, -0.00205732603, 0.0463109687],
    [-0.0279003512, 0.0171403233, 0.00885203853, -0.022806216, -0.0135696121, -0.00375272054, 0.0139928926],
    [0.0116565451, -0.00521556707, -0.0245668497, -0.00946252606, 2.734600e-03, 0.00460146647, -0.0332586318],
    [0.0373648889, 0.040080104, 0.0792120546, 0.0687142611, 0.0129001699, 0.048170276, 6.067640e-02],
    [-0.00489785476, 0.0151357278, 0.0273378156, 0.0379059538, 0.0080597708, 0.0209609158, 0.0248660222],
    [0.00253825542, -1.175260e-02, 0.0339594558, 0.0408501513, 0.0275165718, 0.0101594552, 0.0491689071],
    [5.275800e-02, 0.0415463448, 0.0749897882, 0.0470644757, 0.00624182029, 0.0391805507, 0.03869069],
    [0.0637338459, 0.00614991458, 0.0153763723, 0.0190313365, 0.0142990183, 0.0227143262, 0.0187453162],
    [0.0359746702, 0.0182777364, -0.00368779944, -0.0100486111, 6.89582666E-5, -0.00202751439, 0.0124766938],
    [-0.0151847685, -0.0175893605, 0.0247314386, 0.018632818, 0.00798455066, -0.00110600982, 0.00244264561]
  ]> : tensor<11x7xf32>
  func.return
}

// -----

// The @ragged_dot_mode_1 computation with an extra batch dimension of size 2
// (lhs_batching_dimensions = [0], rhs_batching_dimensions = [1]).
// Within each batch, lhs dimension 1 (11 rows) is ragged and split into
// consecutive groups of sizes [4, 4, 3]. Each group uses the rhs matrix
// selected along rhs_group_dimensions = [0].
// Batch 0 reuses the lhs, rhs, and golden result of @ragged_dot_mode_1, so the
// first 11x7 slice of the expected output is identical to that test's result.
// Batch 1 uses different data.
func.func @ragged_dot_mode_1_batching() {
  // lhs: 2x11x5 -- batch (dim 0), ragged rows (dim 1), contracted (dim 2).
  %lhs = stablehlo.constant dense<[
    [
      [ -0.0999976546, -0.0605386607, 0.126681596, 0.0375950411, 0.0598301813 ],
      [ -0.0343122408, -0.0858866125, 0.103659429, 0.103788935, 0.180407882 ],
      [ 0.0150506198, 0.055824928, 0.149289608, -0.0896283686, -0.0839615092 ],
      [ 0.0589100644, 0.101344816, -0.097690545, 0.0150246918, -0.0799473301 ],
      [ 0.0252457932, 0.106031813, 0.076692991, 0.179130971, 0.153850079 ],
      [ 0.0580786392, -0.0724105313, 0.0961757079, 0.0247998089, 0.110357188 ],
      [ 0.173096269, 0.128659427, -0.0212640986, -0.0857606456, 0.120824583 ],
      [ -0.00152973086, 0.0897915736, 0.126923144, 0.197311223, 0.00960160792 ],
      [ -0.0258883312, 0.194765091, 0.11679814, 0.126006752, 0.0954555795 ],
      [ -0.0781942382, 0.0894904211, 0.165412158, -0.0181870088, 0.0309234336 ],
      [ 0.129948437, 0.0433195308, -0.028667666, -0.0175279453, 0.00777949393 ]
    ],
    [
      [ -0.0500478409, 0.0459552184, 0.16929689, 0.172762454, -0.0818307 ],
      [ 0.171395928, 0.0513568744, 0.0548876, -0.00429011881, 0.195992649 ],
      [ 0.0481930152, -0.0201566443, -0.0727801323, 0.184329301, -0.0778752789 ],
      [ 0.0502121374, 0.0152426511, -0.0168754607, 0.174145252, 0.0589242205 ],
      [ 0.0393337533, 0.182294011, -0.0849748, 0.128454268, 0.131061375 ],
      [ 0.148345202, -0.0623903871, -0.0952396914, 0.10653659, 0.160474151 ],
      [ 0.0888630375, 0.120867364, 0.117623605, 0.199837387, 0.166571677 ],
      [ -0.0300415382, -0.00810345262, 0.00530457497, 0.0539821163, 0.0773340687 ],
      [ 0.153794467, 0.0236242339, 0.152453214, -0.0192048177, 0.0246183872 ],
      [ 0.0611911938, 0.0403752252, -0.013836287, -0.0465016849, -0.053884007 ],
      [ 0.0714964494, 0.140721709, -0.0900838748, 0.0603349432, 0.0495440438 ]
    ]]> : tensor<2x11x5xf32>
  // rhs: 3x2x5x7 -- group (dim 0), batch (dim 1), contracted (dim 2).
  // The 5x7 matrix at [g][b] serves group g of batch b.
  %rhs = stablehlo.constant dense<[
    [
      [
        [ 0.186608255, 0.124487795, 0.0663751587, 0.167221248, 0.0874548, 0.152611881, -0.0520697422 ],
        [ -0.0361745432, 0.114412986, -0.0608718246, -0.0727029, -0.0176235586, -0.0991001204, 0.0242879838 ],
        [ -0.0919371173, 0.112945892, 0.181369215, -0.0280267522, -0.0457312278, -0.00473813713, 0.166097224 ],
        [ 0.0956176, -0.0548994839, 0.104403876, 0.0157444105, 0.0163175985, 0.0499223098, -0.0557401 ],
        [ 0.076156, 0.153672695, 0.0770325884, 0.186622649, 0.066843845, -0.0555545315, 0.194991559 ]
      ],
      [
        [ 0.0226300061, -0.0574540682, 0.0694696084, -0.0243620798, 0.0465543643, 0.0392091647, 0.188328564 ],
        [ -0.0621907599, -0.0400728397, -0.0042250976, 0.0887807682, -0.0619863532, 0.0953761414, 0.0864902064 ],
        [ 0.140921891, -0.0256474689, 0.0429295525, 0.0167942569, -0.0390249, -0.0914874449, 0.170502067 ],
        [ 0.0279492214, -0.0573936924, 0.184246033, 0.0230939165, -0.060643442, 0.165694535, -0.0723479092 ],
        [ -0.051340431, -0.0786809325, 0.00960171223, -0.0240827873, -0.059467189, 0.134945959, 0.0365921929 ]
      ]
    ],
    [
      [
        [ 0.00485724211, 0.0356900468, 0.142683387, 0.179502338, 0.0954938307, -0.0354254842, 0.103877716 ],
        [ 0.172676593, -0.0249623209, 0.158257961, 0.0413787, 0.0517867729, 0.0801181123, 0.14526847 ],
        [ 0.126753062, 0.0386734977, 0.185410261, 0.0898216143, 0.0317991, 0.14740923, 0.106694289 ],
        [ 0.110662006, 0.196143657, 0.186324477, 0.155380905, -0.0132051334, 0.0612277314, 0.054330416 ],
        [ -0.0689698234, 0.0242085531, 0.073015, 0.162969738, 0.0320116058, 0.118924297, 0.160779119 ]
      ],
      [
        [ 0.11469271, 0.140216112, 0.111960642, 0.122514777, -0.0942722782, 0.165809333, 0.0574962273 ],
        [ 0.0389968231, -0.08044184, 0.114026703, 0.0466829464, 0.100303732, 0.104614742, -0.0401335768 ],
        [ 0.174990177, 0.159764826, 0.167005628, 0.0631844923, -0.0582415, 0.0351042375, 0.196808755 ],
        [ -0.035340406, 0.0338070318, -0.00528027117, 0.0543978438, 0.164451241, 0.0319176689, 0.0402595326 ],
        [ 0.141994983, 0.00954742, -0.0365443081, 0.199735016, -0.053918656, 0.0891464874, 0.0849051103 ]
      ]
    ],
    [
      [
        [ -0.0998214856, -0.0997363, 0.132005602, 0.118200503, -0.00424671918, 0.025317125, 0.104748271 ],
        [ 0.104168601, -0.0384214334, 0.150926, 0.112676181, 0.14861238, -0.071635358, -0.0754787177 ],
        [ 0.129201442, 0.088871561, -0.0358443409, -0.0359359607, -0.0756817609, 0.0166469738, 0.185647905 ],
        [ 0.184263527, 0.0169560835, -0.0192355737, 0.10765069, -0.0147894919, 0.13305977, 0.135159582 ],
        [ 0.0267379507, -0.0153532401, -0.0418097563, -0.096605137, -0.0424528457, 0.194970757, -0.0267837271 ]
      ],
      [
        [ 0.145917833, -0.0590635166, 0.0194431096, 0.0803030357, -0.0469358861, 0.148506433, -0.0526806451 ],
        [ 0.196381122, -0.0228494033, -0.0299202427, -0.069508791, -0.0341768041, 0.0904152468, 0.108802207 ],
        [ 0.138430953, 0.108872853, 0.125882119, 0.100856192, 0.0900289789, -0.0830678046, 0.0794649944 ],
        [ -0.0318976864, -0.00436662883, 0.109950341, -0.0647689179, 0.128771216, 0.0578369871, 0.0661734 ],
        [ 0.0763966814, -0.00110008568, 0.110896833, -0.057086423, -0.0514936894, 0.0455975607, 0.158067733 ]
      ]
    ]]> : tensor<3x2x5x7xf32>
  // Row counts of the consecutive groups along lhs dimension 1. The same sizes
  // apply to every batch; here they sum to 11.
  %group_sizes = stablehlo.constant dense<[4, 4, 3]> : tensor<3xi64>
  %result = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [1],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [2],
      lhs_ragged_dimensions = [1],
      rhs_group_dimensions = [0]
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
  // Golden 2x11x7 result: the first 11x7 slice (batch 0) equals the golden
  // result of @ragged_dot_mode_1. The comparison is approximate; its tolerance
  // is defined by the check dialect, not in this file.
  check.expect_almost_eq_const %result, dense<[
    [
      [-0.0199659951, 0.00206358638, 0.0285578221, -0.00411329232, -0.00885893404, -0.0113086831, 0.0343487822],
      [0.0108370036, 0.0196357146, 0.0464844741, 0.032903526, 0.00752512738, -0.00205732603, 0.0463109687],
      [-0.0279003512, 0.0171403233, 0.00885203853, -0.022806216, -0.0135696121, -0.00375272054, 0.0139928926],
      [0.0116565451, -0.00521556707, -0.0245668497, -0.00946252606, 2.734600e-03, 0.00460146647, -0.0332586318],
      [0.0373648889, 0.040080104, 0.0792120546, 0.0687142611, 0.0129001699, 0.048170276, 6.067640e-02],
      [-0.00489785476, 0.0151357278, 0.0273378156, 0.0379059538, 0.0080597708, 0.0209609158, 0.0248660222],
      [0.00253825542, -1.175260e-02, 0.0339594558, 0.0408501513, 0.0275165718, 0.0101594552, 0.0491689071],
      [5.275800e-02, 0.0415463448, 0.0749897882, 0.0470644757, 0.00624182029, 0.0391805507, 0.03869069],
      [0.0637338459, 0.00614991458, 0.0153763723, 0.0190313365, 0.0142990183, 0.0227143262, 0.0187453162],
      [0.0359746702, 0.0182777364, -0.00368779944, -0.0100486111, 6.89582666E-5, -0.00202751439, 0.0124766938],
      [-0.0151847685, -0.0175893605, 0.0247314386, 0.018632818, 0.00798455066, -0.00110600982, 0.00244264561]
    ],
    [
      [0.0288968664, -0.00678509939, 0.0346419513, 0.0141028976, -0.017396003, 0.00451522879, 0.00792134088],
      [-0.0017626211, -0.0284877941, 0.0151375476, -0.00351338694, -0.00874114502, 0.0323345512, 0.0535612516],
      [0.00123786228, -0.00454656407, 0.0335229039, 0.0019464466, -2.14070082E-4, 0.0266590156, -0.0212618597],
      [-3.47743975E-4, -0.017693948, 0.0353507064, 0.00244920771, -0.0120135043, 0.0417729542, -0.0025454592],
      [0.0108208582, -0.0171308704, 0.00553112756, 0.0411250815, 0.0335835591, 0.038393192, -0.00547906291],
      [0.0169365555, 0.0157370344, -0.0128378682, 0.0470919088, -0.00582840201, 0.0324328542, 0.010203423],
      [0.0520783663, 0.0298755895, 0.0362326317, 0.0681023895, 0.0207777359, 0.052735541, 0.0455959477],
      [0.00623999349, -1.49650674E-4, -0.00651274621, 0.0146591738, 0.00641800836, 0.00297434814, 0.00838128477],
      [0.0506783053, 0.00703135319, 0.0220930576, 0.0259224195, 0.001958607, 0.0123232938, 0.00920359604],
      [0.0123091843, -5.780780e-03, -0.0128484722, 0.00679983944, -0.00871101767, 0.0087406747, -0.0115246754],
      [0.0274577513, -0.0175638888, -0.00203213934, -0.0198616516, -0.0110571291, 0.0365728177, 0.0162097216]
    ]
  ]> : tensor<2x11x7xf32>
  func.return
}
